{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2a2fc3b7-3993-49b0-99b6-6300fa39c805",
   "metadata": {},
   "source": [
    "## Tutorial - 0: Continuous Forward Schrodinger\n",
    "\n",
    "Welcome to our first tutorial on Physics-Informed Neural Networks (PINNs) using `PyTorch Lightning` and our new package, **`pinnstorch`**! This guide aims to introduce you to the integration of deep learning and scientific computing via PINNs, which embed physical laws into neural network training.\n",
    "\n",
    "#### Overview:\n",
    "\n",
    "- **Mesh Generation and Sampling:** Understanding how to create and use meshes for training PINNs, crucial for defining the domain of our problem.\n",
    "- **Neural Network Model with PINNs:** How to build and structure a neural network for physical law integration.\n",
    "- **Physical Laws in Neural Networks:** Detailing the implementation of differential equations within the network using `pde_fn` and `output_fn`.\n",
    "- **Training and Validation:** Utilizing PyTorch Lightning’s Trainer to train our model, including defining the training data, initial conditions, and handling boundary conditions.\n",
    "- **Results Visualization:** Saving and analyzing the outcomes of your PINN models."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4182ae5e-18ac-4d01-95c9-8f201913c3e3",
   "metadata": {},
   "source": [
    "#### Install Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff823815-5dac-4084-b414-d08e2f98f367",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install pinnstorch\n",
    "!pip install lightning"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15deac4a-feb1-4a53-bc52-45e8f1a9c04d",
   "metadata": {},
   "source": [
    "#### Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6676140f-6407-47bf-853f-095424674dee",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import lightning.pytorch as pl\n",
    "\n",
    "import pinnstorch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0aebf83-e628-4838-823f-018c675a6600",
   "metadata": {},
   "source": [
    "### Define Mesh\n",
    "\n",
    "Physics-Informed Neural Networks (PINNs) require a discretized domain (mesh) over which the physical equations are solved. In `pinnstorch`, there are primarily two approaches to define this mesh:\n",
    "\n",
    "\n",
    "- **Defining Time and Spatial Domains Separately:** Here, we use `pinnstorch.data.TimeDomain` and `pinnstorch.data.Interval` for creating 1-D spatial domains. These domains are then used to define a `pinnstorch.data.Mesh`.\n",
    "\n",
    "- **Defining Point Clouds:** This method involves directly utilizing spatio-temporal data (e.g., from experiments or other simulations) to create a mesh using `pinnstorch.data.PointCloud`.\n",
    "\n",
    "Both approaches should ideally yield the same results."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ca03b7e-2508-4e0a-8768-6e762dde0d84",
   "metadata": {},
   "source": [
    "#### Option 1: Defining Mesh with Separate Time and Spatial Domains\n",
    "\n",
    "We start by defining a function to read and preprocess the solution data from a file. The output should be a dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7f04332-675a-41b0-b878-82606e4e0286",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data_fn(root_path):\n",
    "    \"\"\"Read and preprocess data from the specified root path.\n",
    "\n",
    "    :param root_path: The root directory containing the data.\n",
    "    :return: Processed data will be used in Mesh class.\n",
    "    \"\"\"\n",
    "\n",
    "    data = pinnstorch.utils.load_data(root_path, \"NLS.mat\")\n",
    "    exact = data[\"uu\"]\n",
    "    exact_u = np.real(exact) # N x T\n",
    "    exact_v = np.imag(exact) # N x T\n",
    "    exact_h = np.sqrt(exact_u**2 + exact_v**2) # N x T\n",
    "    return {\"u\": exact_u, \"v\": exact_v, \"h\": exact_h}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27606399-b8d2-4dce-99eb-b964298158d0",
   "metadata": {},
   "source": [
    "Now, define the time and spatial domains for mesh generation. The choice of these parameters depends on the specific problem being solved and should be set accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e16b89-9294-4149-914d-fa0b578cba47",
   "metadata": {},
   "outputs": [],
   "source": [
    "time_domain = pinnstorch.data.TimeDomain(t_interval=[0, 1.57079633], t_points = 201)\n",
    "spatial_domain = pinnstorch.data.Interval(x_interval= [-5, 4.9609375], shape = [256, 1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "986cb018-49ff-4d0d-a716-0a564d4cbafc",
   "metadata": {},
   "source": [
    "The mesh is then defined using the time and spatial domains along with the read_data_fn function. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b079cfd7-54ab-4b76-927d-de322fa5a67e",
   "metadata": {},
   "outputs": [],
   "source": [
    "mesh = pinnstorch.data.Mesh(root_dir='../data',\n",
    "                            read_data_fn=read_data_fn,\n",
    "                            spatial_domain = spatial_domain,\n",
    "                            time_domain = time_domain)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a175686-09ec-4c0b-a4a4-877da814c8b9",
   "metadata": {},
   "source": [
    "#### Option 2: Using Point Clouds\n",
    "\n",
    "For scenarios where the spatial and temporal data along with solutions are directly available (e.g., from experimental measurements), we can create a mesh using `pinnstorch.data.PointCloud`.\n",
    "\n",
    "**Note:** It's crucial to format the spatial and temporal domain dimensions correctly. Specifically:\n",
    "\n",
    "- **Spatial Domain:** Each axis should be structured in the shape of $(N \\times 1)$, where $N$ represents the number of spatial points.\n",
    "- **Time Domain:** This should be formatted as $(T \\times 1)$, where $T$ indicates the number of time steps.\n",
    "- **Solution Data:** Each solution variable (e.g., temperature, velocity) should be in the shape of $(N \\times T)$, aligning with the spatial and temporal points."
   ]
  },
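  {
   "cell_type": "markdown",
   "id": "3f6b1c2e-7a41-4d09-8e55-0b9c2d7e4a11",
   "metadata": {},
   "source": [
    "As a quick illustration of these shape conventions, the following sketch builds placeholder arrays with NumPy. The grid sizes, the variable names, and the synthetic `field` are assumptions made for this example only; they are not read from `NLS.mat`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical grid sizes, chosen to match the domains used in this tutorial.\n",
    "N, T = 256, 201\n",
    "\n",
    "x = np.linspace(-5.0, 5.0, N, endpoint=False).reshape(-1, 1)  # (N, 1)\n",
    "t = np.linspace(0.0, np.pi / 2, T).reshape(-1, 1)  # (T, 1)\n",
    "\n",
    "# Placeholder complex field; broadcasting (N, 1) with (1, T) gives (N, T).\n",
    "field = (2.0 / np.cosh(x)) * np.exp(1j * t.T)\n",
    "\n",
    "solution = {\n",
    "    'u': np.real(field),  # (N, T)\n",
    "    'v': np.imag(field),  # (N, T)\n",
    "    'h': np.abs(field),   # (N, T)\n",
    "}\n",
    "\n",
    "assert x.shape == (N, 1) and t.shape == (T, 1)\n",
    "assert all(a.shape == (N, T) for a in solution.values())\n",
    "```\n",
    "\n",
    "Arrays laid out this way can be passed as `spatial=[x]`, `time=[t]`, and `solution={...}`, as in the function below."
   ]
  },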
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa987fe1-8bd5-4c23-8231-127346077a73",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data_fn(root_path):\n",
    "    \"\"\"Read and preprocess data from the specified root path.\n",
    "\n",
    "    :param root_path: The root directory containing the data.\n",
    "    :return: Processed data will be used in PointCloud class.\n",
    "    \"\"\"\n",
    "\n",
    "    data = pinnstorch.utils.load_data(root_path, \"NLS.mat\")\n",
    "\n",
    "    x = data[\"x\"].T  # N x 1\n",
    "    t = data[\"tt\"].T  # T x 1\n",
    "    \n",
    "    exact = data[\"uu\"]\n",
    "    exact_u = np.real(exact) # N x T\n",
    "    exact_v = np.imag(exact) # N x T\n",
    "    exact_h = np.sqrt(exact_u**2 + exact_v**2) # N x T\n",
    "    \n",
    "    return pinnstorch.data.PointCloudData(\n",
    "            spatial=[x], time=[t], solution={\"u\": exact_u, \"v\": exact_v, \"h\": exact_h}\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00355618-f848-4809-978e-3d0094f245ed",
   "metadata": {},
   "source": [
    "Now, mesh can be initalize with using the function and the directory to the folder of data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4cdbae0-009b-436b-9db0-1f562085ab54",
   "metadata": {},
   "outputs": [],
   "source": [
    "mesh = pinnstorch.data.PointCloud(root_dir='./data',\n",
    "                                  read_data_fn=read_data_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "687088a4-1a3f-4ccf-a3f4-6bb226981773",
   "metadata": {},
   "source": [
    "### Define Train datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "203acd43-2376-4380-963e-a42b438d2d35",
   "metadata": {},
   "source": [
    "For solving Schrodinger PDE, we have:\n",
    "- Initial condition\n",
    "- Periodic boundary condition\n",
    "- Collection points for the PDE."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "617c0fce-50a5-4313-9b27-747d3f9ea1fe",
   "metadata": {},
   "source": [
    "#### Initial Condition\n",
    "\n",
    "Let's start with initial condition of the Schrodinger.\n",
    "$$ u(0, x) = 2 \\text{sech}(x) $$\n",
    "$$ v(0, x) = 0 $$\n",
    "\n",
    "For defining initial condition, again we have two options.\n",
    "\n",
    "- **Sample from the data.**\n",
    "- **Defining a function for calculating initial condition.**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a97e4020-d83b-459a-93f3-e7abb125ec06",
   "metadata": {},
   "source": [
    "##### Set number of samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8be5824c-6c26-423d-8178-e9310b67dfec",
   "metadata": {},
   "outputs": [],
   "source": [
    "N0 = 50"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a22c4c2f-87c2-4eac-96da-5980791f6cc3",
   "metadata": {},
   "source": [
    "##### Option 1: Sample from the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e230de8-eddd-4c49-97d5-abd06913530f",
   "metadata": {},
   "outputs": [],
   "source": [
    "in_c = pinnstorch.data.InitialCondition(mesh = mesh,\n",
    "                                        num_sample = N0,\n",
    "                                        solution = ['u', 'v'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f35d57e-6243-4775-ab1c-1d070bf20cf1",
   "metadata": {},
   "source": [
    "##### Option 2: Defining a function for calculating initial condition"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7579313b-2079-46b3-806d-7c81efbc6e26",
   "metadata": {},
   "source": [
    "The input of `initial_fun` should be the same as spatial domain of the problem."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "225da7c1-c211-4dfb-b95f-a6b4ac51df52",
   "metadata": {},
   "outputs": [],
   "source": [
    "def initial_fun(x):\n",
    "    return {'u': 2*1/np.cosh(x), 'v': np.zeros_like(x)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "142c93c1-65d8-4cf6-a175-98615ff508c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "in_c = pinnstorch.data.InitialCondition(mesh = mesh,\n",
    "                                        num_sample = N0,\n",
    "                                        initial_fun = initial_fun,\n",
    "                                        solution = ['u', 'v'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d626918-27a8-48e4-888f-703871193d42",
   "metadata": {},
   "source": [
    "The `solution` attribute in `pinnstorch.data.InitialCondition` specifies the solutions (`u` and `v` in our case) to be sampled for initial conditions."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37b283ff-80a1-4074-8c63-0df265a63925",
   "metadata": {},
   "source": [
    "#### Periodic Boundary Condition"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42cbb359-8cd3-4bf0-ac27-70921c2e5c6d",
   "metadata": {},
   "source": [
    "The `pinnstorch.data.PeriodicBoundaryCondition` is used to sample periodic points from the upper and lower bounds of the spatial domain (mesh). The `derivative_order` parameter specifies the order of the derivative to be matched at these boundaries. In our case, for the Schrödinger equation, both the function and its first spatial derivative should match at the boundaries, hence `derivative_order = 1`.\n",
    "\n",
    "\n",
    "$$ u(t,-5) = u(t, 5), $$\n",
    "$$ v(t,-5) = v(t, 5), $$ \n",
    "$$ u_x(t,-5) = u_x(t, 5),$$\n",
    "$$ v_x(t,-5) = v_x(t, 5) $$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7002a433-0e32-441d-a60e-6386f8c2c594",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_b = 50\n",
    "pe_b = pinnstorch.data.PeriodicBoundaryCondition(mesh = mesh,\n",
    "                                                 num_sample = 50,\n",
    "                                                 derivative_order = 1,\n",
    "                                                 solution = ['u', 'v'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17457d3d-3ccd-4b79-b7ab-1975bc9d19a3",
   "metadata": {},
   "source": [
    "#### Mesh Sampler for collection points and solutions\n",
    "\n",
    "In our problem, the partial differential equations (PDEs) governing the dynamics are given by:\n",
    "\n",
    "$$ f_u := u_t + 0.5v_{xx} + v(u^2 +v^2),$$\n",
    "$$ f_v := v_t + 0.5u_{xx} + u(u^2 +v^2) $$\n",
    "\n",
    "To find the solutions to these PDEs using a neural network, we must sample points from the domain at which the network will be trained to satisfy these equations. This sampling process is crucial for training our PINN. We utilize the `pinnstorch.data.MeshSampler` for this purpose, specifying the following:\n",
    "\n",
    "- **Number of Sample Points (N_f):** We choose to sample 20,000 points from the domain. This number is a balance between computational efficiency and the need for a sufficiently dense sampling to capture the dynamics of the PDEs.\n",
    "- **Mesh (mesh):** This parameter defines the spatial-temporal domain from which the points will be sampled.\n",
    "- **Collection Points:** We define `['f_u', 'f_v']` as the targets for our collection points. These are not direct outputs from the neural network but are derived from the network outputs and their derivatives (We will define `pde_fn` function later). The PINN will be trained such that these expressions tend towards zero, aligning with the PDE constraints.\n",
    "\n",
    "Here's the code to implement this sampler:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "492bb3de-453c-49b3-b2e6-add7f3458de7",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_f = 20000\n",
    "me_s = pinnstorch.data.MeshSampler(mesh = mesh,\n",
    "                                   num_sample = N_f,\n",
    "                                   collection_points = ['f_v', 'f_u'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76ed5d5e-f384-4480-b1e5-49d2b88a87ae",
   "metadata": {},
   "source": [
    "### Define Validation dataset\n",
    "\n",
    "For validation, we sample all points from the mesh to evaluate our model comprehensively. Model will be validated for solutions of `u`, `v`, and `h`.\n",
    "\n",
    "**Note:** If `num_sample` is not specified, the sampler will use the entire mesh for data sampling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c530a34-deef-4660-849b-02c240c5db11",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_s = pinnstorch.data.MeshSampler(mesh = mesh,\n",
    "                                    solution = ['u', 'v', 'h'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9959e346-1d1a-4d9c-bf38-d23c956d34f8",
   "metadata": {},
   "source": [
    "### Define Neural Networks\n",
    "\n",
    "Here, we try to define a neural network for solving the problem. For defining a neural network, we should set number of layers and the name of the outputs. Also, domain bounds should be defined. The `lb` and `ub` parameters represent the lower and upper bounds of the spatial-temporal domain, helping in normalizing inputs to the network. Therefore, the inputs of this network are `x` and `t`, and the outputs of this network are `u` and `v`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eeb7181-3dfe-4aa9-b454-61644b1542c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "net = pinnstorch.models.FCN(layers = [2, 100, 100, 100, 100, 2],\n",
    "                            output_names = ['u', 'v'],\n",
    "                            lb=mesh.lb,\n",
    "                            ub=mesh.ub)"
   ]
  },
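  {
   "cell_type": "markdown",
   "id": "8d2e5f70-1c3b-4a96-b7e4-6a0f9c1d2e33",
   "metadata": {},
   "source": [
    "The role of `lb` and `ub` can be made concrete with a common PINN input scaling that maps each input coordinate to $[-1, 1]$. Whether `pinnstorch.models.FCN` uses exactly this form is an assumption here; the formula is a sketch of the idea, not a statement about the library's internals. For each input $z$ (here $x$ and $t$) with bounds $z_{lb}$ and $z_{ub}$:\n",
    "\n",
    "$$ \\hat{z} = 2 \\, \\frac{z - z_{lb}}{z_{ub} - z_{lb}} - 1 $$\n",
    "\n",
    "With the domains from Option 1 of the mesh definition, $x \\in [-5, 4.9609375]$ and $t \\in [0, \\pi/2]$, so $x = -5$ maps to $\\hat{x} = -1$ and $t = \\pi/2$ maps to $\\hat{t} = 1$."
   ]
  },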
  {
   "cell_type": "markdown",
   "id": "260e697e-f68b-4c17-9445-fe319aa06221",
   "metadata": {},
   "source": [
    "### Define `pde_fn` and `output_fn` functions\n",
    "\n",
    "Now, we define `pde_fn` and `output_fn`.\n",
    "- **`output_fn`:** is applied to the network's output, adding any necessary post-processing computations. For example, in our case, `h(x,t) = u(x,t)**2 + v(x,t)**2`, thus, we define this equation in `output_fn`.\n",
    "- **`pde_fn`:** formulates the PDE constraints, which will be used by the `MeshSampler` to compute the loss at the collection points. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5abfe2b-06ef-499f-ab5b-3d37d49c46f0",
   "metadata": {},
   "source": [
    "#### `output_fn` function\n",
    "\n",
    "**Note:** `output_fn` should always have these inputs:\n",
    "- **Outputs:** It is output of the network. In our case, this dictionary should have two output: `u` and `v`.\n",
    "- **Spatial domains:** These are the spatial domain variables. In our case, because our problem has 1-D spatial domain, the input just have `x`. For example, if we had 2-D space, we need another input for that dimention. For example, the inputs from `(outputs, x, t)` will be changed to `(outputs, x, y, t)`.\n",
    "- **Time domin:** The last input of `output_fn` function always should be time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "397a5e17-b9ea-43c8-a486-d9669a6bfeb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def output_fn(outputs: Dict[str, torch.Tensor],\n",
    "              x: torch.Tensor,\n",
    "              t: torch.Tensor):\n",
    "    \"\"\"Define `output_fn` function that will be applied to outputs of net.\"\"\"\n",
    "\n",
    "    outputs[\"h\"] = torch.sqrt(outputs[\"u\"] ** 2 + outputs[\"v\"] ** 2)\n",
    "\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db0f43f5-972b-4be0-94da-40feea4ea3c7",
   "metadata": {},
   "source": [
    "#### `pde_fn` function\n",
    "\n",
    "The inputs are similar to `output_fn`. Only if we have extra variables for training (i.g. in inverse problems), we should add input at the end of inputs. For example, `(outputs, x, t)` will be `(outputs, x, t, extra_variable)`. `extra_variable` is always a dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8763e3c-ebaa-4f59-8a75-515818f379f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pde_fn(outputs: Dict[str, torch.Tensor],\n",
    "           x: torch.Tensor,\n",
    "           t: torch.Tensor):   \n",
    "    \"\"\"Define the partial differential equations (PDEs).\"\"\"\n",
    "    u_x, u_t = pinnstorch.utils.gradient(outputs[\"u\"], [x, t])\n",
    "    v_x, v_t = pinnstorch.utils.gradient(outputs[\"v\"], [x, t])\n",
    "\n",
    "    u_xx = pinnstorch.utils.gradient(u_x, x)[0]\n",
    "    v_xx = pinnstorch.utils.gradient(v_x, x)[0]\n",
    "\n",
    "    outputs[\"f_u\"] = u_t + 0.5 * v_xx + (outputs[\"u\"] ** 2 + outputs[\"v\"] ** 2) * outputs[\"v\"]\n",
    "    outputs[\"f_v\"] = v_t - 0.5 * u_xx - (outputs[\"u\"] ** 2 + outputs[\"v\"] ** 2) * outputs[\"u\"]\n",
    "\n",
    "    return outputs"
   ]
  },
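  {
   "cell_type": "markdown",
   "id": "b47a09c5-2e6d-4f18-a3c1-5d8e7f6a9b22",
   "metadata": {},
   "source": [
    "The two residuals in `pde_fn` can be traced back to a single complex equation. As a sketch, assume the governing equation is the nonlinear Schrödinger equation in the form below. This form is not stated earlier in the tutorial; it is the one consistent with the residuals coded above. Here $\\psi = u + iv$ is the complex field, so the tutorial's `h` is $|\\psi|$:\n",
    "\n",
    "$$ i \\psi_t + 0.5 \\psi_{xx} + |\\psi|^2 \\psi = 0 $$\n",
    "\n",
    "Substituting $\\psi = u + iv$ and using $|\\psi|^2 = u^2 + v^2$ gives\n",
    "\n",
    "$$ i u_t - v_t + 0.5 (u_{xx} + i v_{xx}) + (u^2 + v^2)(u + iv) = 0 $$\n",
    "\n",
    "Collecting the imaginary part, and the real part multiplied by $-1$, yields the two real-valued residuals:\n",
    "\n",
    "$$ f_u := u_t + 0.5 v_{xx} + (u^2 + v^2) v $$\n",
    "$$ f_v := v_t - 0.5 u_{xx} - (u^2 + v^2) u $$"
   ]
  },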
  {
   "cell_type": "markdown",
   "id": "5ed79263-eeff-466d-b5c1-183c0d050329",
   "metadata": {},
   "source": [
    "### Define PINNDataModule and PINNModule\n",
    "\n",
    "To integrate with Lightning, we utilize two specialized modules:\n",
    "\n",
    "- `PINNDataModule` (inherited from `LightningDataModule`) manages data.\n",
    "- `PINNModule` (derived from `LightningModule`) handles the model, compilation, and various enhancements like AMP."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b04abdf5-b97e-433a-8663-a6b2866200d9",
   "metadata": {},
   "source": [
    "#### Define `PINNDataModule`\n",
    "Here, we define collection points, initial condition, and preiodic boundary condition as training datasets, and also, we set validation set. `PINNDataModule` is used for defining training, validation, prediction, and test datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc810afe-fa36-4948-b31f-e5527e54b28b",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_datasets = [me_s, in_c, pe_b]\n",
    "val_dataset = val_s\n",
    "datamodule = pinnstorch.data.PINNDataModule(train_datasets = [me_s, in_c, pe_b],\n",
    "                                            val_dataset = val_dataset,\n",
    "                                            pred_dataset = val_s)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d170a8a-a55e-4d40-9078-4a2f56b32813",
   "metadata": {},
   "source": [
    "#### Define `PINNModule`\n",
    "\n",
    "`PINNModule` handle several things. Here, we will explore the inputs of this class:\n",
    "- **net:**  The neural network.\n",
    "- **pde_fn:** The function representing the PDE to solve.\n",
    "- **optimizer:**  (Optional) The optimizer for training. The default is Adam\n",
    "- **loss_fn:** (Optional) The loss function to use, either \"sse\" or \"mse\". The default is \"sse\".\n",
    "- **scheduler:** (Optional) Learning rate scheduler. The default is None.\n",
    "- **scaler:** (Optional) Gradient scaler for AMP. The default is `torch.cuda.amp.GradScaler`.\n",
    "- **extra_variables:** (Optional) Extra variables in inverse problems. The default is None.\n",
    "- **output_fn:** (Optional) function to process the model's output. The default is None.\n",
    "- **runge_kutta:** (Optional) Runge-Kutta method for solving PDEs in discrete mode. The default is None.\n",
    "- **cudagraph_compile:** Flag to enable CUDA Graph compilation. It works only with a single GPU. The default is True.\n",
    "- **jit_compile:** (Optional) Flag to enable JIT compilation. The default is True.\n",
    "- **amp:** (Optional) Flag to enable Automatic Mixed Precision (AMP). The default is False.\n",
    "- **inline:** (Optional) Flag to enable inline mode in JIT compilation. The default is False.\n",
    "\n",
    "In this example, we initalize `PINNModule` with defined variables. We set Adam optimizer and try to compile the model with CUDA Graph. The loss function here is Mean Square Error (MSE)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a23d83e-d7b6-468e-8204-74b9a40355a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = pinnstorch.models.PINNModule(net = net,\n",
    "                                     pde_fn = pde_fn,\n",
    "                                     output_fn = output_fn,\n",
    "                                     loss_fn = 'mse')"
   ]
  },
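  {
   "cell_type": "markdown",
   "id": "e15c7d38-94ab-4f02-8b6d-2c0a3e5f7d44",
   "metadata": {},
   "source": [
    "The `loss_fn` option distinguishes between \"sse\" and \"mse\". A minimal sketch of the usual definitions of these two reductions follows. The helper names `sse` and `mse` and the sample residuals are hypothetical, and how `PINNModule` combines the losses of several datasets is not shown here.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def sse(residual):\n",
    "    # Sum of squared errors: grows with the number of sample points.\n",
    "    return float(np.sum(np.square(residual)))\n",
    "\n",
    "\n",
    "def mse(residual):\n",
    "    # Mean of squared errors: the same sum divided by the number of points.\n",
    "    return float(np.mean(np.square(residual)))\n",
    "\n",
    "\n",
    "# Hypothetical residuals at four sample points.\n",
    "r = np.array([0.5, -0.5, 1.0, 0.0])\n",
    "sse_value = sse(r)\n",
    "mse_value = mse(r)\n",
    "```\n",
    "\n",
    "The two values differ only by the factor `len(r)`, so the choice mainly changes the scale of the loss and therefore the effective step size for a given learning rate."
   ]
  },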
  {
   "cell_type": "markdown",
   "id": "730294eb-5dd2-46ee-84d7-4743f8bbff13",
   "metadata": {},
   "source": [
    "### Setting Up the Trainer\n",
    "\n",
    "For training our model, we configure a trainer in PyTorch Lightning. Currently, our setup uses a single GPU, as models compiled with CUDA Graph don't support multiple GPUs yet. For a comprehensive understanding of the trainer's options and functionalities, refer to the [official documentation](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.trainer.trainer.Trainer.html). For example, arguments that you can set are:\n",
    "- **accelerator:**  Supports passing different accelerator types such as `cpu`, `gpu`, `tpu`, `ipu`, `hpu`, `mps`, and `auto`.\n",
    "- **devices:** The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate all available devices should be used, or \"auto\" for automatic selection based on the chosen accelerator. Default: \"auto\".\n",
    "- **max_epochs:** Stop training once this number of epochs is reached.\n",
    "- **max_steps:** Stop training after this number of steps.\n",
    "- ...\n",
    "\n",
    "In our example, we configure the trainer for CPU use, specifying one device:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d63a5cc-a0c4-4de7-ba2b-2b10f5140fc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(accelerator='cpu', devices=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5107377f-c811-43c8-8534-644a1dbb640d",
   "metadata": {},
   "source": [
    "### Training\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee7c8b03-1706-4ba4-84c5-0d293a178a9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.fit(model=model, datamodule=datamodule)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82b872bd-6120-4c65-8693-a9281a2d2396",
   "metadata": {},
   "source": [
    "### Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26c2c9b9-52e9-4a89-aa44-be86fbb27e22",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.validate(model=model, datamodule=datamodule)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cd209e5-e554-4baf-9c4b-f78174f8147c",
   "metadata": {},
   "source": [
    "### Plotting\n",
    "\n",
    "For plotting, we need predict the results, and then, we should concatenate the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1a853d0-edc1-494f-a5ff-0ad21870aeb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_list = trainer.predict(model=model, datamodule=datamodule)\n",
    "preds_dict = pinnstorch.utils.fix_predictions(preds_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07f1631e-0f78-4a62-8833-9ae69bf3ada9",
   "metadata": {},
   "outputs": [],
   "source": [
    "pinnstorch.utils.plot_schrodinger(mesh=mesh,\n",
    "                                  preds=preds_dict,\n",
    "                                  train_datasets=train_datasets,\n",
    "                                  val_dataset=val_dataset,\n",
    "                                  file_name='out')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch2stable",
   "language": "python",
   "name": "torch2stable"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
